{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.12 稠密连接网络（DenseNet）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.0\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "print(torch.__version__)\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.12.1 稠密块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(nn.BatchNorm2d(in_channels), \n",
    "                        nn.ReLU(),\n",
    "                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseBlock(nn.Module):\n",
    "    def __init__(self, num_convs, in_channels, out_channels):\n",
    "        super(DenseBlock, self).__init__()\n",
    "        net = []\n",
    "        for i in range(num_convs):\n",
    "            in_c = in_channels + i * out_channels\n",
    "            net.append(conv_block(in_c, out_channels))\n",
    "        self.net = nn.ModuleList(net)\n",
    "        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数\n",
    "\n",
    "    def forward(self, X):\n",
    "        for blk in self.net:\n",
    "            Y = blk(X)\n",
    "            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 23, 8, 8])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = DenseBlock(2, 3, 10)\n",
    "X = torch.rand(4, 3, 8, 8)\n",
    "Y = blk(X)\n",
    "Y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.12.2 过渡层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transition_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(\n",
    "            nn.BatchNorm2d(in_channels), \n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=1),\n",
    "            nn.AvgPool2d(kernel_size=2, stride=2))\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 10, 4, 4])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = transition_block(23, 10)\n",
    "blk(Y).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.12.3 DenseNet模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
    "        nn.BatchNorm2d(64), \n",
    "        nn.ReLU(),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_channels, growth_rate = 64, 32  # num_channels为当前的通道数\n",
    "num_convs_in_dense_blocks = [4, 4, 4, 4]\n",
    "\n",
    "for i, num_convs in enumerate(num_convs_in_dense_blocks):\n",
    "    DB = DenseBlock(num_convs, num_channels, growth_rate)\n",
    "    net.add_module(\"DenseBlosk_%d\" % i, DB)\n",
    "    # 上一个稠密块的输出通道数\n",
    "    num_channels = DB.out_channels\n",
    "    # 在稠密块之间加入通道数减半的过渡层\n",
    "    if i != len(num_convs_in_dense_blocks) - 1:\n",
    "        net.add_module(\"transition_block_%d\" % i, transition_block(num_channels, num_channels // 2))\n",
    "        num_channels = num_channels // 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.add_module(\"BN\", nn.BatchNorm2d(num_channels))\n",
    "net.add_module(\"relu\", nn.ReLU())\n",
    "net.add_module(\"global_avg_pool\", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)\n",
    "net.add_module(\"fc\", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "1  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "2  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "3  output shape:\t torch.Size([1, 64, 24, 24])\n",
      "DenseBlosk_0  output shape:\t torch.Size([1, 192, 24, 24])\n",
      "transition_block_0  output shape:\t torch.Size([1, 96, 12, 12])\n",
      "DenseBlosk_1  output shape:\t torch.Size([1, 224, 12, 12])\n",
      "transition_block_1  output shape:\t torch.Size([1, 112, 6, 6])\n",
      "DenseBlosk_2  output shape:\t torch.Size([1, 240, 6, 6])\n",
      "transition_block_2  output shape:\t torch.Size([1, 120, 3, 3])\n",
      "DenseBlosk_3  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "BN  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "relu  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "global_avg_pool  output shape:\t torch.Size([1, 248, 1, 1])\n",
      "fc  output shape:\t torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand((1, 1, 96, 96))\n",
    "for name, layer in net.named_children():\n",
    "    X = layer(X)\n",
    "    print(name, ' output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.12.4 获取数据并训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.0020, train acc 0.834, test acc 0.749, time 27.7 sec\n",
      "epoch 2, loss 0.0011, train acc 0.900, test acc 0.824, time 25.5 sec\n",
      "epoch 3, loss 0.0009, train acc 0.913, test acc 0.839, time 23.8 sec\n",
      "epoch 4, loss 0.0008, train acc 0.921, test acc 0.889, time 24.9 sec\n",
      "epoch 5, loss 0.0008, train acc 0.929, test acc 0.884, time 24.3 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "# 如出现“out of memory”的报错信息，可减小batch_size或resize\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
    "\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
